{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning: algorithms, objectives, and assumptions\n",
    "(c) Deniz Yuret 2019\n",
    "\n",
    "In this notebook we will analyze three classic learning algorithms.\n",
    "* **Perceptron:** ([Rosenblatt, 1957](https://en.wikipedia.org/wiki/Perceptron)) a neuron model trained with a simple algorithm that updates model weights using the input when the prediction is wrong.\n",
    "* **Adaline:** ([Widrow and Hoff, 1960](https://en.wikipedia.org/wiki/ADALINE)) a neuron model trained with a simple algorithm that updates model weights using the error multiplied by the input (aka least mean square (LMS), delta learning rule, or the Widrow-Hoff rule).\n",
    "* **Softmax classification:** ([Cox, 1958](https://en.wikipedia.org/wiki/Multinomial_logistic_regression)) a multiclass generalization of the logistic regression model from statistics (aka multinomial LR, softmax regression, maxent classifier etc.).\n",
    "\n",
    "We will look at these learners from three different perspectives:\n",
    "* **Algorithm:** First we ask only **how** the learner works, i.e. how it changes after observing each example.\n",
    "* **Objectives:** Next we ask **what** objective guides the algorithm, whether it is optimizing a particular objective function, and whether we can use a generic *optimization algorithm* instead.\n",
    "* **Assumptions:** Finally we ask **why** we think this algorithm makes sense, what prior assumptions it implies, and whether we can use *probabilistic inference* for optimal learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "using Knet, Plots, Statistics, LinearAlgebra, Random\n",
    "Base.argmax(a::KnetArray) = argmax(Array(a))\n",
    "Base.argmax(a::AutoGrad.Value) = argmax(value(a))\n",
    "ENV[\"COLUMNS\"] = 72"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "include(Knet.dir(\"data/mnist.jl\"))\n",
    "xtrn, ytrn, xtst, ytst = mnist()\n",
    "ARRAY = Array{Float32}\n",
    "xtrn, xtst = ARRAY(mat(xtrn)), ARRAY(mat(xtst))\n",
    "# onehot: turn integer labels into a (classes x examples) indicator matrix\n",
    "onehot(y) = (m=ARRAY(zeros(maximum(y),length(y))); for i in 1:length(y); m[y[i],i]=1; end; m)\n",
    "ytrn, ytst = onehot(ytrn), onehot(ytst);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "println.(summary.((xtrn, ytrn, xtst, ytst)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NTRN,NTST,XDIM,YDIM = size(xtrn,2), size(xtst,2), size(xtrn,1), size(ytrn,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model weights\n",
    "w = ARRAY(randn(YDIM,XDIM))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class scores\n",
    "w * xtrn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions\n",
    "[ argmax(w * xtrn[:,i]) for i in 1:NTRN ]'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correct answers\n",
    "[ argmax(ytrn[:,i]) for i in 1:NTRN ]'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy\n",
    "acc(w,x,y) = mean(argmax(w * x, dims=1) .== argmax(y, dims=1))\n",
    "acc(w,xtrn,ytrn), acc(w,xtst,ytst)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop\n",
    "function train(algo,x,y,T=2^20)\n",
    "    w = ARRAY(zeros(size(y,1),size(x,1)))\n",
    "    nexamples = size(x,2)\n",
    "    nextprint = 1\n",
    "    for t = 1:T\n",
    "        i = rand(1:nexamples)\n",
    "        algo(w, x[:,i], y[:,i])  # <== this is where w is updated\n",
    "        if t == nextprint\n",
    "            println((iter=t, accuracy=acc(w,x,y), wnorm=norm(w)))\n",
    "            nextprint = min(2t,T)\n",
    "        end\n",
    "    end\n",
    "    w\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perceptron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
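    "function perceptron(w,x,y)\n",
    "    guess = argmax(w * x)   # predicted class: highest score\n",
    "    class = argmax(y)       # correct class from the one-hot label\n",
    "    if guess != class       # update only on mistakes\n",
    "        w[class,:] .+= x    # raise the score of the correct class\n",
    "        w[guess,:] .-= x    # lower the score of the wrongly predicted class\n",
    "    end\n",
    "end"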
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (iter = 1048576, accuracy = 0.8950333333333333, wnorm = 1321.2463f0) in 7 secs\n",
    "@time wperceptron = train(perceptron,xtrn,ytrn);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adaline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function adaline(w,x,y; lr=0.0001)\n",
    "    error = w * x - y        # class scores minus the one-hot target\n",
    "    w .-= lr * error * x'    # step against the error times the input\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (iter = 1048576, accuracy = 0.8523, wnorm = 1.2907721f0) in 31 secs with lr=0.0001\n",
    "@time wadaline = train(adaline,xtrn,ytrn);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Softmax classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
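    "function softmax(w,x,y; lr=0.01)\n",
    "    probs = exp.(w * x)      # unnormalized class probabilities\n",
    "    probs ./= sum(probs)     # normalize so they sum to 1\n",
    "    error = probs - y        # predicted probabilities minus the one-hot target\n",
    "    w .-= lr * error * x'    # step against the error times the input\n",
    "end"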
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (iter = 1048576, accuracy = 0.9242166666666667, wnorm = 26.523603f0) in 32 secs with lr=0.01\n",
    "@time wsoftmax = train(softmax,xtrn,ytrn);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Objectives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training via optimization\n",
    "function optimize(loss,x,y; lr=0.1, iters=2^20)\n",
    "    w = Param(ARRAY(zeros(size(y,1),size(x,1))))\n",
    "    nexamples = size(x,2)\n",
    "    nextprint = 1\n",
    "    for t = 1:iters\n",
    "        i = rand(1:nexamples)\n",
    "        L = @diff loss(w, x[:,i], y[:,i])\n",
    "        ∇w = grad(L,w)\n",
    "        w .-= lr * ∇w\n",
    "        if t == nextprint\n",
    "            println((iter=t, accuracy=acc(w,x,y), wnorm=norm(w)))\n",
    "            nextprint = min(2t,iters)\n",
    "        end\n",
    "    end\n",
    "    w\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perceptron minimizes the score difference between the predicted class and the correct class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function perceptronloss(w,x,y)\n",
    "    score = w * x\n",
    "    guess = argmax(score)\n",
    "    class = argmax(y)\n",
    "    score[guess] - score[class]\n",
    "end"
   ]
  },
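  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of why this loss reproduces the perceptron update. The symbols $s$, $g$, $c$, and $\\eta$ are notation introduced here for illustration: $s = wx$ are the class scores, $g$ is the guessed class, $c$ is the correct class, and $\\eta$ is the learning rate. Treating $g$ as fixed (it is piecewise constant in $w$), the loss and its gradient with respect to the rows of $w$ are\n",
    "\n",
    "$$J(w) = s_g - s_c = w_g x - w_c x, \\qquad \\frac{\\partial J}{\\partial w_g} = x^\\top, \\qquad \\frac{\\partial J}{\\partial w_c} = -x^\\top,$$\n",
    "\n",
    "with all other rows zero, and the whole gradient zero when $g = c$. A gradient step therefore gives $w_g \\leftarrow w_g - \\eta x^\\top$ and $w_c \\leftarrow w_c + \\eta x^\\top$, which with $\\eta = 1$ is the update in `perceptron`."
   ]
  },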
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (iter = 1048576, accuracy = 0.8908833333333334, wnorm = 1322.4888f0) in 62 secs with lr=1\n",
    "@time wperceptron2 = optimize(perceptronloss,xtrn,ytrn,lr=1);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adaline minimizes the squared error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function quadraticloss(w,x,y)\n",
    "    0.5 * sum(abs2, w * x - y)\n",
    "end"
   ]
  },
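  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of why gradient descent on this loss is the Adaline rule, with $\\eta$ (the learning rate) as notation introduced here for illustration. For a single example $(x, y)$,\n",
    "\n",
    "$$J(w) = \\tfrac{1}{2} \\lVert wx - y \\rVert^2, \\qquad \\nabla_w J = (wx - y)\\, x^\\top,$$\n",
    "\n",
    "so the step $w \\leftarrow w - \\eta\\,(wx - y)\\, x^\\top$ is the update in `adaline`, where `error = w * x - y`."
   ]
  },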
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (iter = 1048576, accuracy = 0.8498333333333333, wnorm = 1.2882874f0) in 79 secs with lr=0.0001\n",
    "@time wadaline2 = optimize(quadraticloss,xtrn,ytrn,lr=0.0001);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Softmax classifier maximizes the probabilities of correct answers\n",
    "(or minimizes negative log likelihood, aka cross-entropy or softmax loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function negloglik(w,x,y)\n",
    "    probs = exp.(w * x)\n",
    "    probs = probs / sum(probs)\n",
    "    class = argmax(y)\n",
    "    -log(probs[class])\n",
    "end"
   ]
  },
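  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of why gradient descent on this loss is the softmax update. The symbols $s$, $p$, $c$, and $\\eta$ are notation introduced here for illustration: $s = wx$ are the class scores, $p_k = e^{s_k} / \\sum_j e^{s_j}$ are the class probabilities, $c$ is the correct class, and $\\eta$ is the learning rate. With a one-hot $y$,\n",
    "\n",
    "$$J(w) = -\\log p_c = -s_c + \\log \\sum_j e^{s_j}, \\qquad \\frac{\\partial J}{\\partial s_k} = p_k - y_k, \\qquad \\nabla_w J = (p - y)\\, x^\\top,$$\n",
    "\n",
    "so the step $w \\leftarrow w - \\eta\\,(p - y)\\, x^\\top$ is the update in `softmax`, where `error = probs - y`."
   ]
  },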
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (iter = 1048576, accuracy = 0.9283833333333333, wnorm = 26.593485f0) in 120 secs with lr=0.01\n",
    "@time wsoftmax2 = optimize(negloglik,xtrn,ytrn,lr=0.01);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.1.0",
   "language": "julia",
   "name": "julia-1.1"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.1.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
